library(tidyverse)
library(data.table)
library(dtplyr)
library(tidytext)

# Run repeated random-split cross-validation of an out-of-sample text
# classifier (probability random forest fit with ranger on a trimmed
# document-feature matrix).
#
# Arguments:
#   dat              data frame / tibble with one row per document. Besides the
#                    columns named below it must contain a column `caucus`
#                    coded "R"/"D", which is what the fit statistics are
#                    scored against.
#   outvar           character, name of the outcome column used to fit the model
#   text_col         character, name of the column holding the text
#   doc_col          character, name of the column holding the document id
#   to_compound      character vector of multi-word terms to compound (added to
#                    a small built-in list)
#   stem             logical, whether to stem tokens (default TRUE)
#   ngrams           numeric vector of n-gram sizes to tokenize to (default
#                    NULL, but we use it in the final model)
#   date_var         character, name of a numeric date column to append as a
#                    feature (default NULL, but we use it in the final model)
#   volume_var       character, name of a volume column to append as a feature
#                    (default NULL, we don't use it in the final model)
#   weightvar        character, name of a column whose log(x + 1) rescales each
#                    document's term counts (default NULL, we don't use it in
#                    the final model)
#   store_importance character, type of variable importance to store in ranger
#                    (default "none")
#   term_min         minimum share of docs a term has to appear in to be kept
#                    in the dfm (default 1%); terms in more than 50% of docs
#                    are always dropped
#   trainshare       share of docs assigned to the training set (default 70%)
#   upsample         logical, whether to resample the smaller outcome class of
#                    each training set (with replacement) up to the size of the
#                    larger one (default FALSE); written for a binary outcome
#   n_folds          number of random train/test splits (default is 15 -- five
#                    runs on three cores)
#   seed             random seed; the per-fold seeds are derived from it
#
# Returns a list with:
#   fit_results  data frame with accuracy (acc), area under the ROC curve
#                (roc), area under the precision-recall curve (prc) and the
#                fold number, one row per fold
#   preds        the test-set rows of `dat` for every fold, with added columns
#                preds, classpred, correct and fold
#   importance   list of per-fold ranger variable importances, or NULL when
#                store_importance is "none"
text_to_pred_multifolds <- function(dat, outvar,
                                    text_col, doc_col,
                                    to_compound,
                                    stem = TRUE,
                                    ngrams = NULL,
                                    date_var = NULL,
                                    volume_var = NULL,
                                    weightvar = NULL,
                                    store_importance = "none",
                                    term_min = .01, 
                                    trainshare = .7,
                                    upsample = FALSE,
                                    n_folds = 15,
                                    seed = 11111){
  # attach the packages whose functions are called unqualified below;
  # library() stops on a missing package, whereas require() only warns
  for (pkg in c("tidyverse", "quanteda", "foreach", "doMC")) {
    library(pkg, character.only = TRUE)
  }
  # these two are only called through `::`, so just check they are installed
  for (pkg in c("ranger", "PRROC")) {
    if (!requireNamespace(pkg, quietly = TRUE)) {
      stop("Package '", pkg, "' is required but not installed.", call. = FALSE)
    }
  }
  
  # fail before the expensive preprocessing if a needed column is absent
  # (`caucus` is hard-coded in the scoring step further down)
  needed_cols <- c(outvar, text_col, doc_col, date_var, volume_var, weightvar,
                   "caucus")
  absent_cols <- setdiff(needed_cols, names(dat))
  if (length(absent_cols) > 0) {
    stop("Column(s) missing from `dat`: ",
         paste(absent_cols, collapse = ", "), call. = FALSE)
  }
  
  # leave one core free, but never ask for fewer than one worker
  n_cores <- parallel::detectCores() - 1
  if (is.na(n_cores) || n_cores < 1) {
    n_cores <- 1
  }
  registerDoMC(cores = n_cores)
  
  set.seed(seed)
  
  # define index
  dat$index <- dat[[doc_col]]
  
  message("Making corpus")
  corp <- quanteda::corpus(dat, text_field = text_col, docid_field = "index")
  
  message("Setting compounds...")
  
  compounds <- quanteda::phrase(
    c("white house", "fox news", "united states",
      "united nations",
      "social media", "climate change", "right wing",
      to_compound)
  )
  
  message("Tokenizing...")
  
  # tokenize/preprocess/remove extraneous twitter stuff
  tok <- quanteda::tokens(corp,
                          remove_punct = TRUE,
                          remove_twitter = TRUE,
                          remove_numbers = TRUE,
                          remove_url = TRUE) %>%
    quanteda::tokens_remove("\\p{Z}", valuetype = "regex") %>%
    
    # anchored so that only the token "via" is removed; an unanchored regex
    # also removes every token that merely contains "via"
    quanteda::tokens_remove("^via$", valuetype = "regex") %>%
    
    # remove stop words
    quanteda::tokens_remove(quanteda::stopwords(source = "snowball")) %>%
    
    # concatenate common bigrams
    quanteda::tokens_compound(pattern = compounds)
  
  if (isTRUE(stem)) {
    message("Stemming...")
    tok <- tok %>% quanteda::tokens_wordstem()
  }
  if (!is.null(ngrams)) {
    # collapse first: `ngrams` may be a vector such as 1:3
    message(paste0("ngramming to n = ", paste(ngrams, collapse = ", "), "..."))
    
    # ngram
    tok <- tok %>% quanteda::tokens_ngrams(ngrams)
  }
  
  # make full document frequency matrix
  dfm.full <- quanteda::dfm(tok, verbose = TRUE, tolower = TRUE)
  
  # discard very frequent/infrequent terms
  message("Trimming dfm...")
  
  dfm.lim <- quanteda::dfm_trim(
    dfm.full,
    min_docfreq = term_min,
    max_docfreq = .5,
    docfreq_type = "prop")
  
  message(paste0("Retained ", ncol(dfm.lim), " tokens for analysis..."))
  
  # put back in matrix
  X <- as.matrix(dfm.lim)
  
  # row of `dat` behind each row of X
  doc_rows <- match(rownames(X), dat$index)
  
  if (!is.null(weightvar)) {
    message("Weighting dfm...")
    # a length-nrow vector recycles down the columns, so row i is scaled by
    # the weight of document i
    doc_weights <- log(as.numeric(dat[[weightvar]][doc_rows]) + 1)
    X <- X * doc_weights
  }
  
  # flag empty docs from the text features only, i.e. before covariates are
  # appended: a date or volume value must neither hide an empty document nor
  # cancel out a non-empty one
  drops <- which(rowSums(X) == 0)
  
  message("Prepping IV and DV...")
  
  # bind one covariate column, named after its source column, onto the features
  append_covariate <- function(mat, values, label) {
    cbind(mat, matrix(as.numeric(values), ncol = 1,
                      dimnames = list(NULL, label)))
  }
  
  if (!is.null(date_var)) {
    # make date variable and append to text features
    X <- append_covariate(X, dat[[date_var]][doc_rows], date_var)
  }
  
  if (!is.null(volume_var)) {
    # make volume variable and append to text features
    X <- append_covariate(X, dat[[volume_var]][doc_rows], volume_var)
  }
  
  # make outcome vector
  Y <- dat[[outvar]][doc_rows]
  
  # discard empty docs
  if (length(drops) > 0) {
    Y <- Y[-drops]
    X <- X[-drops, , drop = FALSE]
  }
  
  n_obs <- length(Y)
  if (n_obs == 0) {
    stop("No non-empty documents left after trimming the dfm.", call. = FALSE)
  }
  n_train <- floor(n_obs * trainshare)
  
  # Forked workers do not continue the parent's random stream, so set.seed()
  # above cannot make the folds reproducible by itself. Draw one seed per fold
  # here, in the parent, and re-seed at the start of every fold.
  fold_seeds <- sample.int(.Machine$integer.max, n_folds)
  
  # make train/test splits
  pred_routine <- foreach::foreach(i = seq_len(n_folds)) %dopar% {
    
    set.seed(fold_seeds[i])
    
    message(paste0("Splitting train and test sets for fold number ", i, "..."))
    
    which_train <- sample.int(n_obs, n_train, replace = FALSE)
    # explicit complement: negative indexing selects nothing when
    # which_train is empty
    which_test <- setdiff(seq_len(n_obs), which_train)
    
    # define IVs and DV for training set, IVs for test set
    X_train <- X[which_train, , drop = FALSE]
    Y_train <- Y[which_train]
    
    X_test <- X[which_test, , drop = FALSE]
    
    if (isTRUE(upsample)) {
      message("Upsampling the training set...")
      
      class_counts <- table(Y_train)
      
      # smaller class (the first one on a tie, in which case nothing is added)
      minority <- names(class_counts)[which.min(class_counts)]
      minority_rows <- which(as.character(Y_train) == minority)
      
      n_extra <- max(class_counts) - min(class_counts)
      
      # index through sample.int(): sample(x, ...) would draw from 1:x when
      # x has length one
      extras <- minority_rows[sample.int(length(minority_rows), n_extra,
                                         replace = TRUE)]
      
      indices_up <- c(seq_len(nrow(X_train)), extras)
      
      X_train <- X_train[indices_up, , drop = FALSE]
      Y_train <- Y_train[indices_up]
    }
    
    # bind training outcome/features
    trainbind <- cbind(Y_train, X_train)
    
    message(paste0("Modeling fold number ", i, "..."))
    
    # model
    rf1 <- ranger::ranger(data = trainbind,
                          dependent.variable.name = "Y_train",
                          classification = TRUE,
                          probability = TRUE,
                          importance = store_importance,
                          verbose = FALSE)
    
    message(paste0("Predicting test set for fold number ", i, "..."))
    
    # take first column of predicted probabilities
    test_probs <- predict(rf1, X_test)$predictions[, 1]
    
    # generate predictions; match() keeps the rows in the order of X_test
    pt <- dat[match(rownames(X_test), dat$index), ] %>%
      mutate(preds = test_probs) %>%
      mutate(classpred = ifelse(preds >= .5, "R", "D")) %>%
      mutate(correct = (classpred == caucus),
             fold = i)
    
    # if the classes got flipped in RF prediction, flip back
    # NOTE(review): presumably needed because the order of the probability
    # columns is not pinned across folds -- confirm against ranger
    if (mean(pt$correct) < .5) {
      pt <- pt %>% mutate(preds = 1 - preds) %>%
        mutate(classpred = ifelse(preds >= .5, "R", "D")) %>%
        mutate(correct = (classpred == caucus))
    }
    
    # take measures of fit
    # overall accuracy, area under roc curve, and area under precision-recall
    # curve for current fold. PRROC takes the classifier scores in
    # scores.class0 and the 0/1 class membership in weights.class0.
    is_positive <- as.numeric(pt$caucus == "R")
    
    aucroc <- PRROC::roc.curve(scores.class0 = pt$preds,
                               weights.class0 = is_positive)$auc
    
    auprc <- PRROC::pr.curve(scores.class0 = pt$preds,
                             weights.class0 = is_positive)$auc.integral
    
    outmat <- data.frame(acc = mean(pt$correct),
                         roc = aucroc,
                         prc = auprc,
                         fold = i)
    
    # if using variable importance, pull that out too
    if (store_importance != "none") {
      imp <- rf1$variable.importance
    } else {
      imp <- NULL
    }
    
    list(fit_results = outmat,
         test_set_data = pt,
         importance = imp)
  }
  
  # stack measures of fit across folds
  outmat <- bind_rows(lapply(pred_routine, function(x) x$fit_results))
  
  test_sets <- bind_rows(lapply(pred_routine, function(x) x$test_set_data))
  
  # if storing importance, pull that out
  if (store_importance != "none") {
    imp <- lapply(pred_routine, function(x) x$importance)
  } else {
    imp <- NULL
  }
  
  message("Done!")
  list(fit_results = outmat,
       preds = test_sets,
       importance = imp)
}

### this data starts out subsetted to tweets flagged by any of the sub-dictionaries
data_dir <- "~/Dropbox/CHAMP-Net/coronavirus_paper/data_and_code"

covid.tweets.nrt <- data.table::fread(file.path(data_dir, "covid_tweets_nrt_4120.csv"))

covid.tweets.nrt[, week := lubridate::week(date)]
covid.tweets.nrt[, month := lubridate::month(date)]

# generalize hashtags starting with 'china' to 'china_hashtag'
covid.tweets.nrt[, lowtext := stringr::str_replace_all(lowtext, "\\s(#china[\\w]+)", " china_hashtag")]

covid.tweets.nrt <- as_tibble(covid.tweets.nrt)

# both cross-validation runs use the same documents, compounds and threshold,
# so define them once
cv_dat <- covid.tweets.nrt %>% filter(date < "2020-04-01")

cv_compounds <- c("green new deal", "nuclear power", "cap and trade", "clean coal", "gun violence", "national security",
                  "tax cut", "cut tax", "cut taxes", "president trump", "american people", "trump administration",
                  "medicare for all", "health care", "health insurance",
                  "single payer", "assault weapon", "assault weapons", "semi automatic", "second amendment", "brady bill",
                  "high capacity magazine", "high capacity magazines", "bump stock", "bump stocks", "background check",
                  "background checks", "build the wall", "wall funding", "birthright citizenship", "nuclear deal",
                  "pro life", "pro choice", "anti choice", "born alive", "partial birth", "late term",
                  "house democrats", "house republicans", "senate democrats", "senate republicans",
                  "majority leader", "minority leader", "town hall", "donald trump", "social security",
                  "law enforcement", "preexisting conditions",
                  
                  # covid compounds
                  "world health organization", "centers for disease control",
                  "supply chain", "shelter in place", "defense production act",
                  "personal protective equipment", "laid off",
                  "wish list", "op ed")

# NOTE(review): the denominator is the full data set, not the date-filtered
# subset that is actually modelled -- confirm this is the intended threshold
cv_term_min <- 100 / nrow(covid.tweets.nrt)

# divert messages (progress output) to a log file for both runs
zz <- file("cv_outfile.txt", open = "wt")
sink(zz, append = TRUE, type = "message")

tryCatch({
  full_cv <- text_to_pred_multifolds(dat = cv_dat,
                                     outvar = "is_republican",
                                     text_col = "lowtext",
                                     doc_col = "index",
                                     to_compound = cv_compounds,
                                     term_min = cv_term_min,
                                     date_var = "day_relative_1120",
                                     ngrams = c(1:3))
  
  data.table::fwrite(full_cv$preds, file = file.path(data_dir, "output", "fullcv_preds.csv"))
  data.table::fwrite(full_cv$fit_results, file = file.path(data_dir, "output", "fullcv_cvfit.csv"))
  
  full_cv_upsample <- text_to_pred_multifolds(dat = cv_dat,
                                              outvar = "is_republican",
                                              text_col = "lowtext",
                                              doc_col = "index",
                                              to_compound = cv_compounds,
                                              term_min = cv_term_min,
                                              date_var = "day_relative_1120",
                                              ngrams = c(1:3),
                                              upsample = TRUE)
  
  data.table::fwrite(full_cv_upsample$preds, file = file.path(data_dir, "output", "fullcv_upsample_preds.csv"))
  data.table::fwrite(full_cv_upsample$fit_results, file = file.path(data_dir, "output", "fullcv_upsample_cvfit.csv"))
}, finally = {
  # the message stream has to be restored with type = "message" (a bare sink()
  # only pops the output stack), and the log connection closed, even when a
  # run fails -- otherwise later errors stay invisible
  sink(type = "message")
  close(zz)
})
